{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0261b70c-32e2-4482-9f75-c0dfbf4f9285",
   "metadata": {},
   "source": [
    "### Few-shot learning testing\n",
    "\n",
    "Here we train classifiers to decode the object shape from the activations of a range of layers within the model, using varying amounts of training data.\n",
    "\n",
    "Tested with TensorFlow 2.11.0 and Python 3.10.9."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bec94eeb-27ca-4bc5-af6f-bac6d96b34d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import zipfile\n",
    "import os\n",
    "import psutil\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.models import Model\n",
    "from PIL import Image, ImageOps\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import tensorflow.keras.backend as K\n",
    "import matplotlib.pyplot as plt\n",
    "from tensorflow.keras import Model, Sequential, metrics, optimizers, layers\n",
    "from tensorflow.python.framework.ops import disable_eager_execution\n",
    "from utils import load_tfds_dataset\n",
    "from generative_model import encoder_network_large, decoder_network_large, VAE, build_encoder_decoder_large\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.pipeline import make_pipeline\n",
    "\n",
    "# Fix the random seed (123) up front so that repeated runs of the notebook give the same results\n",
    "tf.keras.utils.set_random_seed(123)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c3494d0-27da-4c1f-817d-05de7d11eaf1",
   "metadata": {},
   "source": [
    "#### Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "302f6cd7-db0b-431d-befd-05e4f88f7327",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Images are handled as (height, width, channels)\n",
    "K.set_image_data_format('channels_last')\n",
    "\n",
    "latent_dim = 20\n",
    "input_shape = (64, 64, 3)\n",
    "\n",
    "# Build the encoder/decoder pair and wrap them in the VAE\n",
    "encoder, decoder = build_encoder_decoder_large(latent_dim=latent_dim)\n",
    "vae = VAE(encoder, decoder, 1)\n",
    "\n",
    "# Restore the saved weights for each half of the model\n",
    "for sub_model, weights_path in (\n",
    "    (vae.encoder, \"model_weights/shapes3d_encoder.h5\"),\n",
    "    (vae.decoder, \"model_weights/shapes3d_decoder.h5\"),\n",
    "):\n",
    "    sub_model.load_weights(weights_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b3e949da-3801-44ec-b6c9-a772c4ad44eb",
   "metadata": {},
   "source": [
    "#### Train classifier\n",
    "\n",
    "(This uses the same method as the classifier used to measure decoding accuracy over time.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce9e87d4-7f35-4e6c-a41d-d193cb1402f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the shapes3d train/test splits together with their 'label_shape' labels\n",
    "train_ds, test_ds, train_labels, test_labels = load_tfds_dataset(\n",
    "    'shapes3d',\n",
    "    labels=True,\n",
    "    key_dict={'shapes3d': 'label_shape'},\n",
    ")\n",
    "\n",
    "# Divide the images by 255 (presumably scaling 8-bit pixel values to [0, 1] - confirm against load_tfds_dataset)\n",
    "train_ds = train_ds / 255\n",
    "test_ds = test_ds / 255"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "776eb826-3f8c-48a0-b4de-b14566e51423",
   "metadata": {},
   "source": [
    "Get names of intermediate layers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d3e4b52-62c0-4e0c-8352-94a2b7bb717a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every encoder layer by name, one per line\n",
    "for encoder_layer in vae.encoder.layers:\n",
    "    print(encoder_layer.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ccb970-d267-441e-9193-74d17e893bb3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Training-set sizes to evaluate (number of labelled examples given to the classifier)\n",
    "num_examples = [10, 20, 50, 100, 200, 300, 400, 500, 1000]\n",
    "\n",
    "# scores[i][j] is the test accuracy for layers_to_consider[i] trained on num_examples[j] examples\n",
    "scores = []\n",
    "\n",
    "# Names of the encoder layers whose activations are decoded\n",
    "layers_to_consider = ['input_1', 'conv2d', 'conv2d_1', 'conv2d_2', 'conv2d_3', 'mean']\n",
    "\n",
    "for layer in layers_to_consider:\n",
    "\n",
    "    # Sub-model mapping the encoder input to this layer's output\n",
    "    intermediate_layer_model = Model(inputs=vae.encoder.input,\n",
    "                                     outputs=vae.encoder.get_layer(layer).output)\n",
    "\n",
    "    # The test-set representations depend only on the layer, not on the number of\n",
    "    # training examples, so compute and flatten them once per layer instead of\n",
    "    # once per training-set size\n",
    "    intermediate_test = intermediate_layer_model.predict(test_ds)\n",
    "    intermediate_test = np.reshape(intermediate_test, (intermediate_test.shape[0], -1))\n",
    "\n",
    "    # Scores for this layer, one per entry of num_examples\n",
    "    layer_scores = []\n",
    "\n",
    "    for num in num_examples:\n",
    "        # Use the first `num` training examples as the few-shot training set\n",
    "        train_subset = train_ds[:num]\n",
    "        train_labels_subset = train_labels[:num]\n",
    "\n",
    "        # Skip subsets containing fewer than two classes; record NaN so that\n",
    "        # layer_scores stays aligned with num_examples\n",
    "        if len(np.unique(train_labels_subset)) < 2:\n",
    "            layer_scores.append(np.nan)\n",
    "            continue\n",
    "\n",
    "        # Intermediate representations of the training subset, flattened to\n",
    "        # one feature vector per example\n",
    "        intermediate_train = intermediate_layer_model.predict(train_subset)\n",
    "        intermediate_train = np.reshape(intermediate_train, (intermediate_train.shape[0], -1))\n",
    "\n",
    "        # Standardise the features and fit the SVM classifier\n",
    "        clf = make_pipeline(StandardScaler(), SVC(probability=True))\n",
    "        clf.fit(intermediate_train, train_labels_subset)\n",
    "\n",
    "        # Score the classifier on the full test set\n",
    "        layer_scores.append(clf.score(intermediate_test, test_labels))\n",
    "\n",
    "    # store the scores for this layer\n",
    "    scores.append(layer_scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d54da8e4-e2a7-4c50-9fe0-1db2991f80c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Human-readable legend labels, keyed by encoder layer name\n",
    "layer_name_dict = {\n",
    "    'input_1': 'Input',\n",
    "    'conv2d': '1st conv. layer',\n",
    "    'conv2d_1': '2nd conv. layer',\n",
    "    'conv2d_2': '3rd conv. layer',\n",
    "    'conv2d_3': '4th conv. layer',\n",
    "    'mean': 'Latent code',\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df55dfb8-ae83-457d-94c7-bfe886c56a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy against training-set size, one line per encoder layer\n",
    "fig, ax = plt.subplots(figsize=(10, 6))\n",
    "for layer_idx, layer_scores in enumerate(scores):\n",
    "    legend_label = layer_name_dict[layers_to_consider[layer_idx]]\n",
    "    ax.plot(num_examples, layer_scores, label=legend_label)\n",
    "\n",
    "# Axis labels, tick sizes, title and legend\n",
    "ax.set_xlabel('Number of training examples', fontsize=16)\n",
    "ax.set_ylabel('Accuracy', fontsize=16)\n",
    "ax.tick_params(axis='both', labelsize=14)\n",
    "ax.set_title('Classifier performance based on the number of training examples', fontsize=16)\n",
    "ax.legend(fontsize=14)\n",
    "\n",
    "# Save the figure to disk, then display it\n",
    "fig.savefig('few_shot_testing.png', bbox_inches='tight', dpi=300)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
